from __future__ import print_function

import torch.nn as nn
import torch.nn.functional as F
import torch

class DistillKL(nn.Module):
    """Temperature-scaled KL-divergence distillation loss between logits.

    Args:
        T: softmax temperature applied to both student and teacher logits.
            The final loss is multiplied by ``T * T``.
        type: loss variant, stored as ``self.kd_type``. ``'skd'`` first
            standardizes each logit vector over its last dimension
            (``F.layer_norm`` with no affine parameters), multiplies the
            result by ``SKD_SCALE``, and then applies the KL divergence over
            the last dimension. Any other value applies the KL divergence
            directly to the raw logits over dim 1.
    """

    # Multiplier applied to the standardized logits in the 'skd' variant.
    SKD_SCALE = 3
    # Epsilon passed to F.layer_norm in the 'skd' variant.
    SKD_EPS = 1e-7

    def __init__(self, T, type):
        super(DistillKL, self).__init__()
        self.T = T
        self.kd_type = type

    def forward(self, y_s, y_t):
        """Return the distillation loss of student logits ``y_s`` w.r.t. teacher ``y_t``.

        Computes ``KL(softmax(y_t / T) || softmax(y_s / T))`` summed over all
        elements, divided by ``y_s.shape[0]`` and multiplied by ``T * T``.
        ``y_s`` and ``y_t`` must have the same shape.
        """
        if self.kd_type == 'skd':
            # Standardize over the actual class dimension. A hard-coded
            # size of 100 would only work for exactly 100 classes.
            y_s = F.layer_norm(y_s, y_s.shape[-1:], None, None, self.SKD_EPS) * self.SKD_SCALE
            y_t = F.layer_norm(y_t, y_t.shape[-1:], None, None, self.SKD_EPS) * self.SKD_SCALE
            # Use the same axis as layer_norm for both distributions.
            dim = -1
        else:
            dim = 1
        log_p_s = F.log_softmax(y_s / self.T, dim=dim)
        p_t = F.softmax(y_t / self.T, dim=dim)
        # reduction='sum' replaces the deprecated size_average=False. Dividing
        # by the batch size gives the 'batchmean' reduction.
        loss = F.kl_div(log_p_s, p_t, reduction='sum') / y_s.shape[0]
        return loss * self.T * self.T